#
# SPDX-FileCopyrightText: Copyright 2024-2025 Arm Limited and/or its affiliates <open-source-office@arm.com>
#
# SPDX-License-Identifier: Apache-2.0
#

load(
    "//:kai_defs.bzl",
    "kai_c_library",
    "kai_cpu_bf16",
    "kai_cpu_dotprod",
    "kai_cpu_fp16",
    "kai_cpu_i8mm",
    "kai_cpu_neon",
    "kai_cpu_sme",
    "kai_cpu_sme2",
)

# Targets in this package are private by default; a target must set
# `visibility` itself to be usable from other packages.
package(default_visibility = ["//visibility:private"])

# Extension-less path stems of the kernels built by :scalar_impl. Each entry
# is expanded to `<entry>.c` (source) and `<entry>.h` (textual header); that
# target sets no cpu_uarch.
# buildifier: keep sorted
SCALAR_KERNELS = [
    "pack/kai_lhs_quant_pack_qai8dxp_f32",
    "pack/kai_lhs_quant_pack_qsi8d32p_f32",
    "pack/kai_rhs_pack_kxn_qsi4cxp_qs4cxs1s0",
    "pack/kai_rhs_pack_nxk_qsi4cxp_qs4cxs1s0",
]

# Extension-less path stems of the kernels built by :neon_impl, which passes
# cpu_uarch = kai_cpu_neon(). Each entry is expanded to `<entry>.c` and
# `<entry>.h`.
# buildifier: keep sorted
NEON_KERNELS = [
    "pack/kai_lhs_quant_pack_qsi8d32p_f32_neon",
    "pack/kai_rhs_pack_kxn_f32p8x1biasf32_f32_f32_neon",
    "pack/kai_rhs_pack_kxn_qsi4c32p_qsu4c32s1s0",
    "pack/kai_rhs_pack_kxn_qsi8cxp_qsi8cx_neon",
    "pack/kai_rhs_pack_nxk_qsi4c32p_qsu4c32s1s0",
    "pack/kai_rhs_pack_nxk_qsi4c32ps1s0scalef16_qsu4c32s16s0_neon",
    "pack/kai_rhs_pack_nxk_qsi4c32pscalef16_qsu4c32s16s0",
    "pack/kai_rhs_pack_nxk_qsi4cxps1s0_qsu4cxs1s0_neon",
    "pack/kai_rhs_pack_nxk_qsi8cxp_qsi8cx_neon",
]

# Extension-less path stems of the kernels built by :neon_impl_asm. Each entry
# is expanded to three files: `<entry>_asm.S`, `<entry>.c` and `<entry>.h`.
# buildifier: keep sorted
NEON_KERNELS_ASM = [
    "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p8x1biasf32_6x8x4_neon_mla",
]

# Extension-less path stems of the kernels built by :fp16_impl, which passes
# cpu_uarch = kai_cpu_fp16(). Each entry is expanded to `<entry>.c` and
# `<entry>.h`.
# buildifier: keep sorted
FP16_KERNELS = [
    "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p16x1biasf16_6x16x8_neon_mla",
    "pack/kai_rhs_pack_kxn_f16p16x1biasf16_f16_f16_neon",
]

# Extension-less path stems of the kernels built by :bf16_impl, which passes
# cpu_uarch = kai_cpu_bf16(). Each entry is expanded to `<entry>.c` and
# `<entry>.h`.
# buildifier: keep sorted
BF16_KERNELS = [
    "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p1x4_bf16p12x4b_1x36_neon_dot",
    "matmul_clamp_f32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p8x4_bf16p12x4b_8x12_neon_mmla",
    "pack/kai_lhs_quant_pack_bf16p1x4_f32_neon",
    "pack/kai_lhs_quant_pack_bf16p8x4_f32_neon",
    "pack/kai_rhs_quant_pack_kxn_bf16p12x4biasf32_f32_neon",
]

# Extension-less path stems of the kernels built by :fp16_bf16_impl, which
# passes cpu_uarch = kai_cpu_fp16() + kai_cpu_bf16(). Each entry is expanded
# to `<entry>.c` and `<entry>.h`.
# buildifier: keep sorted
FP16_BF16_KERNELS = [
    "matmul_clamp_f16_bf16p_bf16p/kai_matmul_clamp_f16_bf16p8x4_bf16p12x4b_8x12_neon_mmla",
    "pack/kai_lhs_pack_bf16p8x4_f16_neon",
    "pack/kai_rhs_pack_kxn_bf16p12x4biasf16_f16_neon",
    "pack/kai_rhs_pack_kxn_bf16p12x4biasf32_f16_neon",
]

# Extension-less path stems of the kernels built by :dotprod_impl, which
# passes cpu_uarch = kai_cpu_dotprod(). Each entry is expanded to `<entry>.c`
# and `<entry>.h`.
# buildifier: keep sorted
DOTPROD_KERNELS = [
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4x4_1x4_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp4x8_1x4x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi4cxp8x8_1x8x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi4cxp8x4_8x8x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x4_16x4x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi8cxp4x4_1x4_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp1x8_qsi8cxp4x8_1x4_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x4_qsi8cxp4x4_16x4_neon_dotprod",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4x4_1x4_neon_dotprod",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x8_qsi4c32p4x8_1x4x32_neon_dotprod",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x4_qsi4c32p4x4_16x4_neon_dotprod",
]

# Extension-less path stems of the kernels built by :dotprod_impl_asm. Each
# entry is expanded to three files: `<entry>_asm.S`, `<entry>.c` and
# `<entry>.h`.
# buildifier: keep sorted
DOTPROD_KERNELS_ASM = [
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p4x4_1x4_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x4_qsi4c32p8x4_1x8_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p4x8_1x4x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp1x8_qsi4c32p8x8_1x8x32_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p4x4_16x4_neon_dotprod",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x4_qsi4c32p8x4_4x8_neon_dotprod",
]

# Extension-less path stems of the kernels built by :i8mm_impl, which passes
# cpu_uarch = kai_cpu_i8mm(). Each entry is expanded to `<entry>.c` and
# `<entry>.h`.
# buildifier: keep sorted
I8MM_KERNELS = [
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_4x4x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp4x8_8x4x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_4x8x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi4cxp8x8_8x8x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi8cxp/kai_matmul_clamp_f32_qai8dxp4x8_qsi8cxp4x8_16x4_neon_i8mm",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_16x4_neon_i8mm",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p4x8_qsi4c32p4x8_8x4x32_neon_i8mm",
]

# Extension-less path stems of the kernels built by :i8mm_impl_asm. Each entry
# is expanded to three files: `<entry>_asm.S`, `<entry>.c` and `<entry>.h`.
# buildifier: keep sorted
I8MM_KERNELS_ASM = [
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_16x4x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p4x8_8x4x32_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8_neon_i8mm",
    "matmul_clamp_f32_qai8dxp_qsi4c32p/kai_matmul_clamp_f32_qai8dxp4x8_qsi4c32p8x8_4x8x32_neon_i8mm",
]

# Extension-less path stems of the kernels built by :sme_impl, which passes
# cpu_uarch = kai_cpu_sme(). Each entry is expanded to `<entry>.c` and
# `<entry>.h`.
# buildifier: keep sorted
SME_KERNELS = [
    "pack/kai_lhs_pack_bf16p2vlx2_f32_sme",
    "pack/kai_lhs_pack_f32p2vlx1_f32_sme",
    "pack/kai_lhs_pack_x16p2vlx2_x16_sme",
    "pack/kai_lhs_pack_x8p2vlx4_x8_sme",
    "pack/kai_rhs_pack_kxn_bf16p2vlx2b_f32_x32_sme",
    "pack/kai_rhs_pack_kxn_f32p16vlx1b_f32_f32_sme",
    "pack/kai_rhs_pack_kxn_f32p2vlx1biasf32_f32_f32_sme",
    "pack/kai_rhs_pack_kxn_qsi8cxp2vlx4sb_qs8cx_f32_i32_sme",
    "pack/kai_rhs_pack_kxn_x16p2vlx2b_x16_x16_sme",
    "pack/kai_rhs_pack_nxk_f32p2vlx1biasf32_f32_f32_sme",
    "pack/kai_rhs_pack_nxk_x16p2vlx2b_x16_x16_sme",
]

# Extension-less path stems of the kernels built by :sme2_impl, which passes
# cpu_uarch = kai_cpu_sme2(). Each entry is expanded to `<entry>.c` and
# `<entry>.h`.
# buildifier: keep sorted
SME2_KERNELS = [
    "matmul_clamp_f16_f16_f16p/kai_matmul_clamp_f16_f16_f16p2vlx2b_1x16vl_sme2_dot",
    "matmul_clamp_f16_f16p_f16p/kai_matmul_clamp_f16_f16p2vlx2_f16p2vlx2_2vlx2vl_sme2_mopa",
    "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p16vlx1b_1x16vl_sme2_mla",
    "matmul_clamp_f32_f32_f32p/kai_matmul_clamp_f32_f32_f32p2vlx1b_1x16vl_sme2_mla",
    "matmul_clamp_f32_f32p_f32p/kai_matmul_clamp_f32_f32p2vlx1_f32p2vlx1biasf32_sme2_mopa",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1vlx8_qsi4cxp4vlx8_1vlx4vl_sme2_mopa",
    "matmul_clamp_f32_qai8dxp_qsi4cxp/kai_matmul_clamp_f32_qai8dxp1x4_qsi4cxp4vlx4_1x4vl_sme2_sdot",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1vlx4_qsi4c32p4vlx4_1vlx4vl_sme2_mopa",
    "matmul_clamp_f32_qsi8d32p_qsi4c32p/kai_matmul_clamp_f32_qsi8d32p1x4_qsi4c32p4vlx4_1x4vl_sme2_sdot",
    "matmul_clamp_fp32_bf16p_bf16p/kai_matmul_clamp_f32_bf16p2vlx2_bf16p2vlx2_2vlx2vl_sme2_mopa",
    "matmul_clamp_qai8_qai8_qsi8cxp/kai_matmul_clamp_qai8_qai8_qsi8cxp2vlx4sb_1x16vl_sme2_dot",
    "matmul_clamp_qai8_qai8p_qsi8cxp/kai_matmul_clamp_qai8_qai8p2vlx4_qsi8cxpsb2vlx4_2vlx2vl_sme2_mopa",
]

# Public target exposing every `*_interface.h` header matched by a recursive
# glob under this package as textual headers. It has no sources of its own.
kai_c_library(
    name = "interface",
    textual_hdrs = glob(["**/*_interface.h"]),
    visibility = ["//visibility:public"],
)

# Library for the kernels listed in SCALAR_KERNELS. Every list entry
# contributes `<entry>.c` as a source and `<entry>.h` as a textual header.
# No cpu_uarch is passed for this target.
kai_c_library(
    name = "scalar_impl",
    srcs = ["{}.c".format(stem) for stem in SCALAR_KERNELS],
    textual_hdrs = ["{}.h".format(stem) for stem in SCALAR_KERNELS],
)

# Library for the kernels listed in NEON_KERNELS. Every list entry contributes
# `<entry>.c` as a source and `<entry>.h` as a textual header; cpu_uarch is
# taken from kai_cpu_neon().
kai_c_library(
    name = "neon_impl",
    srcs = ["{}.c".format(stem) for stem in NEON_KERNELS],
    cpu_uarch = kai_cpu_neon(),
    textual_hdrs = ["{}.h".format(stem) for stem in NEON_KERNELS],
)

# Library for the kernels listed in NEON_KERNELS_ASM. Sources are all
# `<entry>_asm.S` files followed by all `<entry>.c` files (the suffix loop is
# the outer one, so that ordering is kept); `<entry>.h` is the textual header.
# cpu_uarch is taken from kai_cpu_neon().
kai_c_library(
    name = "neon_impl_asm",
    srcs = [
        stem + suffix
        for suffix in ("_asm.S", ".c")
        for stem in NEON_KERNELS_ASM
    ],
    cpu_uarch = kai_cpu_neon(),
    textual_hdrs = ["{}.h".format(stem) for stem in NEON_KERNELS_ASM],
)

# Library for the kernels listed in FP16_KERNELS. Every list entry contributes
# `<entry>.c` as a source and `<entry>.h` as a textual header; cpu_uarch is
# taken from kai_cpu_fp16().
kai_c_library(
    name = "fp16_impl",
    srcs = ["{}.c".format(stem) for stem in FP16_KERNELS],
    cpu_uarch = kai_cpu_fp16(),
    textual_hdrs = ["{}.h".format(stem) for stem in FP16_KERNELS],
)

# Library for the kernels listed in BF16_KERNELS. Every list entry contributes
# `<entry>.c` as a source and `<entry>.h` as a textual header; cpu_uarch is
# taken from kai_cpu_bf16().
kai_c_library(
    name = "bf16_impl",
    srcs = ["{}.c".format(stem) for stem in BF16_KERNELS],
    cpu_uarch = kai_cpu_bf16(),
    textual_hdrs = ["{}.h".format(stem) for stem in BF16_KERNELS],
)

# Library for the kernels listed in FP16_BF16_KERNELS. Every list entry
# contributes `<entry>.c` as a source and `<entry>.h` as a textual header.
# cpu_uarch is the concatenation of kai_cpu_fp16() and kai_cpu_bf16(), in
# that order.
kai_c_library(
    name = "fp16_bf16_impl",
    srcs = ["{}.c".format(stem) for stem in FP16_BF16_KERNELS],
    cpu_uarch = kai_cpu_fp16() + kai_cpu_bf16(),
    textual_hdrs = ["{}.h".format(stem) for stem in FP16_BF16_KERNELS],
)

# Library for the kernels listed in DOTPROD_KERNELS. Every list entry
# contributes `<entry>.c` as a source and `<entry>.h` as a textual header;
# cpu_uarch is taken from kai_cpu_dotprod().
kai_c_library(
    name = "dotprod_impl",
    srcs = ["{}.c".format(stem) for stem in DOTPROD_KERNELS],
    cpu_uarch = kai_cpu_dotprod(),
    textual_hdrs = ["{}.h".format(stem) for stem in DOTPROD_KERNELS],
)

# Library for the kernels listed in DOTPROD_KERNELS_ASM. Sources are all
# `<entry>_asm.S` files followed by all `<entry>.c` files (the suffix loop is
# the outer one, so that ordering is kept); `<entry>.h` is the textual header.
# cpu_uarch is taken from kai_cpu_dotprod().
kai_c_library(
    name = "dotprod_impl_asm",
    srcs = [
        stem + suffix
        for suffix in ("_asm.S", ".c")
        for stem in DOTPROD_KERNELS_ASM
    ],
    cpu_uarch = kai_cpu_dotprod(),
    textual_hdrs = ["{}.h".format(stem) for stem in DOTPROD_KERNELS_ASM],
)

# Library for the kernels listed in I8MM_KERNELS. Every list entry contributes
# `<entry>.c` as a source and `<entry>.h` as a textual header; cpu_uarch is
# taken from kai_cpu_i8mm().
kai_c_library(
    name = "i8mm_impl",
    srcs = ["{}.c".format(stem) for stem in I8MM_KERNELS],
    cpu_uarch = kai_cpu_i8mm(),
    textual_hdrs = ["{}.h".format(stem) for stem in I8MM_KERNELS],
)

# Library for the kernels listed in I8MM_KERNELS_ASM. Sources are all
# `<entry>_asm.S` files followed by all `<entry>.c` files (the suffix loop is
# the outer one, so that ordering is kept); `<entry>.h` is the textual header.
# cpu_uarch is taken from kai_cpu_i8mm().
kai_c_library(
    name = "i8mm_impl_asm",
    srcs = [
        stem + suffix
        for suffix in ("_asm.S", ".c")
        for stem in I8MM_KERNELS_ASM
    ],
    cpu_uarch = kai_cpu_i8mm(),
    textual_hdrs = ["{}.h".format(stem) for stem in I8MM_KERNELS_ASM],
)

# Library for the kernels listed in SME_KERNELS. Every list entry contributes
# `<entry>.c` as a source and `<entry>.h` as a textual header; cpu_uarch is
# taken from kai_cpu_sme().
kai_c_library(
    name = "sme_impl",
    srcs = ["{}.c".format(stem) for stem in SME_KERNELS],
    cpu_uarch = kai_cpu_sme(),
    textual_hdrs = ["{}.h".format(stem) for stem in SME_KERNELS],
)

# Library for the kernels listed in SME2_KERNELS. Every list entry contributes
# `<entry>.c` as a source and `<entry>.h` as a textual header; cpu_uarch is
# taken from kai_cpu_sme2().
kai_c_library(
    name = "sme2_impl",
    srcs = ["{}.c".format(stem) for stem in SME2_KERNELS],
    cpu_uarch = kai_cpu_sme2(),
    textual_hdrs = ["{}.h".format(stem) for stem in SME2_KERNELS],
)

# Public aggregate target. It declares no sources or headers of its own and
# simply depends on :interface plus every per-feature implementation library
# defined in this file, so consumers need only a single dependency.
kai_c_library(
    name = "matmul",
    visibility = ["//visibility:public"],
    deps = [
        ":bf16_impl",
        ":dotprod_impl",
        ":dotprod_impl_asm",
        ":fp16_bf16_impl",
        ":fp16_impl",
        ":i8mm_impl",
        ":i8mm_impl_asm",
        ":interface",
        ":neon_impl",
        ":neon_impl_asm",
        ":scalar_impl",
        ":sme2_impl",
        ":sme_impl",
    ],
)
